# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import functools
import random

import cv2
import numpy as np


def wrap_keepdims(func):
    """Decorator that keeps image results three-dimensional.

    The decorated function must receive an array with exactly three dims
    (``(H, W, C)``) as its first positional argument; anything else raises
    ``ValueError``. If the decorated function returns a 2-D array (as OpenCV
    does for single-channel outputs), a trailing channel axis is appended so
    the caller always gets a 3-D result back. 3-D results are returned
    unchanged.
    """

    @functools.wraps(func)
    def _keepdims_wrapper(image, *args, **kwargs):
        ndim = len(image.shape)
        if ndim != 3:
            raise ValueError("image must have 3 dims, but got {} dims".format(ndim))
        result = func(image, *args, **kwargs)
        # Re-attach the channel axis that cv2 drops for single-channel data.
        return result[:, :, np.newaxis] if len(result.shape) == 2 else result

    return _keepdims_wrapper


@wrap_keepdims
def to_gray(image):
    r"""
    Convert a BGR image to grayscale.

    ``cv2.cvtColor`` with ``COLOR_BGR2GRAY`` produces a 2-D array; the
    ``wrap_keepdims`` decorator appends the channel axis back, so the result
    has shape (H, W, 1).

    :param image: Input BGR format image, with (H, W, C) shape
    :return: Gray format image, with (H, W, 1) shape
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return gray


@wrap_keepdims
def to_bgr(image):
    r"""
    Convert a grayscale image to BGR color space.

    :param image: Input gray format image, with (H, W, C) shape
    :return: BGR format image, with (H, W, C) shape
    """
    bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    return bgr


@wrap_keepdims
def pad(input, size, value):
    r"""
    Pad input data with *value* and given *size*

    :param input: Input data, with (H, W, C) shape
    :param size: Padding size of input data, either an integer or a sequence.
        If it's an integer, the input data will be padded by that amount on
        all four sides.
        If it's a sequence of two integers ``(bottom, right)``, only the
        bottom and right sides of the input data will be padded.
        If it's a sequence of four integers ``(top, bottom, left, right)``,
        each side is padded by the corresponding amount.
    :param value: Padding value, a scalar or a sequence of ints or floats.
        If it is floating point, the input data is cast to float32 first so
        that the padding value is not truncated.
    :return: Padded data, with (H', W', C) shape
    :raises ValueError: If *size* is a sequence whose length is neither 2 nor 4.
    """
    if isinstance(size, (int, np.integer)):
        size = (int(size),) * 4
    else:
        # tuple() also accepts non-Sequence iterables such as numpy arrays.
        size = tuple(size)
        if len(size) == 2:
            size = (0, size[0], 0, size[1])
        elif len(size) != 4:
            raise ValueError(
                "pad size must be an int or a sequence of 2 or 4 ints, "
                "but got {}".format(size)
            )
    # `dtype == float` only matches float64; any floating dtype (e.g. a
    # np.float32 value) must trigger the cast, otherwise cv2 would
    # saturate-cast the fractional padding value into an integer image.
    if np.issubdtype(np.asarray(value).dtype, np.floating):
        input = input.astype(np.float32, copy=False)
    return cv2.copyMakeBorder(input, *size, cv2.BORDER_CONSTANT, value=value)


@wrap_keepdims
def flip(image, flipCode):
    r"""
    Flip the input image according to *flipCode* (the type of flip).

    :param image: Input image, with (H, W, C) shape
    :param flipCode: Code that indicates the type of flip, passed straight
        to ``cv2.flip``:
        1 : Flip horizontally
        0 : Flip vertically
        -1 : Flip horizontally and vertically
    :return: Flipped image with the same (H, W, C) shape and the same color
        format as *image*
    """
    return cv2.flip(image, flipCode=flipCode)


@wrap_keepdims
def resize(input, size, interpolation=cv2.INTER_LINEAR):
    r"""
    Resize the input data to the given size.

    :param input: Input data, could be image or masks, with (H, W, C) shape
    :param size: Target size of input data, as a (height, width) pair.
    :param interpolation: Interpolation method. If a sequence of methods is
        given, one of them is picked with ``random.choice`` on every call.
    :return: Resized data, with (H, W, C) shape
    :raises ValueError: If *size* does not have exactly two elements.
    """
    if len(size) != 2:
        raise ValueError("resize needs (h, w), but got {}".format(size))

    if isinstance(interpolation, collections.abc.Sequence):
        method = random.choice(interpolation)
    else:
        method = interpolation
    # OpenCV takes dsize as (width, height), i.e. the reverse of (h, w).
    dsize = size[::-1]
    return cv2.resize(input, dsize, interpolation=method)
